/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements. See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership. The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License. You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied. See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */
#include <thrift/thrift-config.h>

#include <thrift/transport/TSSLServerSocket.h>
#include <thrift/transport/TSSLSocket.h>
#include <thrift/concurrency/IOService.h>

namespace apache {
namespace thrift {
namespace transport {

using namespace std;
using boost::shared_ptr;
using namespace apache::thrift::concurrency;
using boost::asio::deadline_timer;
using boost::asio::ip::tcp;

/**
 * Creates an SSL server socket that will listen on @p port.
 *
 * The acceptor and the accept deadline timer are bound to the default
 * io_service returned by IOService::getDefaultIOService(). All tunables start
 * at their "off" values:
 * - send/receive timeouts and TCP buffer sizes are 0, so they are not applied;
 * - the accept timeout is -1, so no accept deadline is armed;
 * - the retry settings are 0 and keep-alive is false.
 *
 * The factory is switched into server mode with server(true). It is the same
 * factory createSocket() uses for accepted connections.
 */
TSSLServerSocket::TSSLServerSocket(int port, const boost::shared_ptr<TSSLSocketFactory>& factory)
  : factory_(factory),
    port_(port),
    acceptor_(IOService::getDefaultIOService()),
    deadline_(IOService::getDefaultIOService()),
    acceptBacklog_(DEFAULT_BACKLOG),
    sendTimeout_(0),
    recvTimeout_(0),
    accTimeout_(-1),
    retryLimit_(0),
    retryDelay_(0),
    tcpSendBuffer_(0),
    tcpRecvBuffer_(0),
    keepAlive_(false)
{
  // Sockets produced by this factory sit on the accepting side of the connection.
  this->factory_->server(true);
}

/**
 * Creates an SSL server socket on @p port whose accepted clients receive the
 * given send/receive timeouts.
 *
 * acceptImpl() applies each timeout only when it is greater than 0. Every
 * other setting matches the (port, factory) constructor:
 * - the accept timeout is -1;
 * - the retry and TCP buffer settings are 0;
 * - keep-alive is off;
 * - the default io_service backs the acceptor and timer.
 *
 * The factory is put into server mode with server(true).
 */
TSSLServerSocket::TSSLServerSocket(int port,
                                   int sendTimeout,
                                   int recvTimeout,
                                   const boost::shared_ptr<TSSLSocketFactory>& factory)
  : factory_(factory),
    port_(port),
    acceptor_(IOService::getDefaultIOService()),
    deadline_(IOService::getDefaultIOService()),
    acceptBacklog_(DEFAULT_BACKLOG),
    sendTimeout_(sendTimeout),
    recvTimeout_(recvTimeout),
    accTimeout_(-1),
    retryLimit_(0),
    retryDelay_(0),
    tcpSendBuffer_(0),
    tcpRecvBuffer_(0),
    keepAlive_(false)
{
  // Sockets produced by this factory sit on the accepting side of the connection.
  this->factory_->server(true);
}

/**
 * Closes the acceptor, if it is still open, when the server socket is destroyed.
 */
TSSLServerSocket::~TSSLServerSocket()
{
  // Inside a destructor the call resolves to this class's close() either way;
  // qualifying it just makes that explicit.
  TSSLServerSocket::close();
}

/**
 * Returns a new TSSLSocket obtained from factory_.
 *
 * acceptImpl() uses the returned socket as the target of the asynchronous
 * accept.
 */
boost::shared_ptr<TSSLSocket> TSSLServerSocket::createSocket()
{
  boost::shared_ptr<TSSLSocket> socket(factory_->createSocket());
  return socket;
}

void TSSLServerSocket::setSendTimeout(int sendTimeout)
{
  sendTimeout_ = sendTimeout;
}

void TSSLServerSocket::setRecvTimeout(int recvTimeout)
{
  recvTimeout_ = recvTimeout;
}

void TSSLServerSocket::setAcceptTimeout(int accTimeout)
{
  accTimeout_ = accTimeout;
}

void TSSLServerSocket::setAcceptBacklog(int accBacklog)
{
  acceptBacklog_ = accBacklog;
}

void TSSLServerSocket::setRetryLimit(int retryLimit)
{
  retryLimit_ = retryLimit;
}

void TSSLServerSocket::setRetryDelay(int retryDelay)
{
  retryDelay_ = retryDelay;
}

void TSSLServerSocket::setTcpSendBuffer(int tcpSendBuffer)
{
  tcpSendBuffer_ = tcpSendBuffer;
}

void TSSLServerSocket::setTcpRecvBuffer(int tcpRecvBuffer)
{
  tcpRecvBuffer_ = tcpRecvBuffer;
}

void TSSLServerSocket::setKeepAlive(bool keepAlive)
{
  keepAlive_ = keepAlive;
}

void TSSLServerSocket::listen()
{
  using namespace boost::asio;
  ip::tcp::endpoint endpoint(ip::tcp::v4(), port_);
  acceptor_.open(endpoint.protocol());
  boost::system::error_code ec;

  acceptor_.set_option(ip::tcp::acceptor::reuse_address(true), ec);
  if (ec) {
    throw TTransportException(TTransportException::NOT_OPEN,
      "Could not set TCP_DEFER_ACCEPT",
      ec.value());
  }

  // Set TCP buffer sizes
  if (tcpSendBuffer_ > 0) {
    boost::system::error_code ec;
    boost::asio::socket_base::send_buffer_size option(tcpSendBuffer_);
    acceptor_.set_option(option, ec);
    if (ec) {
      throw TTransportException(TTransportException::NOT_OPEN,
        "Could not set SO_SNDBUF",
        ec.value());
    }
  }

  if (tcpRecvBuffer_ > 0) {
    boost::system::error_code ec;
    boost::asio::socket_base::receive_buffer_size option(tcpRecvBuffer_);
    acceptor_.set_option(option, ec);
    if (ec) {
      throw TTransportException(TTransportException::NOT_OPEN,
        "Could not set SO_RCVBUF",
        ec.value());
    }
  }

  // Turn linger off, don't want to block on calls to close
  boost::asio::socket_base::linger lingeroption(false, 0);
  acceptor_.set_option(lingeroption, ec);

  if (ec) {
    throw TTransportException(TTransportException::NOT_OPEN, "Could not set SO_LINGER", ec.value());
  }

  {
    // TCP Nodelay, speed over bandwidth
    boost::system::error_code ec;
    boost::asio::ip::tcp::no_delay option(true);
    acceptor_.set_option(option, ec);
    if (ec) {
      throw TTransportException(TTransportException::NOT_OPEN,
        "Could not set TCP_NODELAY",
        ec.value());
    }
  }

  acceptor_.native_non_blocking(true, ec);
  if (ec) {
    throw TTransportException(TTransportException::NOT_OPEN, "native_non_blocking() failed", ec.value());
  }

  // prepare the port information
  // we may want to try to bind more than once, since THRIFT_NO_SOCKET_CACHING doesn't
  // always seem to work. The client can configure the retry variables.
  int retries = 0;

  {
    acceptor_.bind(endpoint, ec);
    if (ec) {
      throw TTransportException(TTransportException::NOT_OPEN,
        "Could not bind",
        ec.value());
    }
    if (port_ == 0) {
      port_ = acceptor_.local_endpoint().port();
    }
  }


  acceptor_.listen(acceptBacklog_, ec);
  if (ec) {
    throw TTransportException(TTransportException::NOT_OPEN,
      "Could not listen",
      ec.value());
  }

  // The socket is now listening!
}

int TSSLServerSocket::getPort()
{
  return port_;
}

struct ssl_accept_handler
{
  ssl_accept_handler(const boost::shared_ptr<boost::system::error_code>& ec,
                     const boost::shared_ptr<boost::mutex>& mux,
                     const boost::shared_ptr<boost::condition_variable>& cnd)
      : error(ec)
      , tmux(mux)
      , cond(cnd)
  {

  }

  void operator () (const boost::system::error_code& ec)
  {
    boost::unique_lock<boost::mutex> lock(*tmux);
    *error = ec;
    lock.unlock();
    cond->notify_one();
  }

  boost::shared_ptr<boost::system::error_code> error;
  boost::shared_ptr<boost::mutex> tmux;
  boost::shared_ptr<boost::condition_variable> cond;
};

/**
 * Blocks until a client connects and returns it as a TSSLSocket.
 *
 * The returned socket has this server's send/receive timeouts and keep-alive
 * setting applied and has been opened with open().
 *
 * How it works:
 * - The accept is started asynchronously on the acceptor.
 * - When accTimeout_ > 0, a deadline timer is armed for accTimeout_
 *   milliseconds.
 * - This thread then waits on a condition variable until a completion handler
 *   reports. NOTE(review): this presumes some other thread is running the
 *   default io_service (see IOService); confirm.
 *
 * The accept and the timer write to separate error_code slots. An expired
 * deadline_timer completes with a *success* code, so sharing one slot would
 * make a timeout look like a successful accept. On timeout, the pending accept
 * is cancelled and its handler is awaited before returning, so the client
 * socket is never destroyed while the accept operation still refers to it.
 *
 * @throws TTransportException
 *   - NOT_OPEN if the acceptor is not open;
 *   - TIMED_OUT if accTimeout_ elapses before a client connects;
 *   - UNKNOWN (carrying the Asio error value) if the accept fails or is
 *     cancelled, for example by interrupt().
 */
shared_ptr<TTransport> TSSLServerSocket::acceptImpl()
{
  if (!acceptor_.is_open()) {
    throw TTransportException(TTransportException::NOT_OPEN, "TServerSocket not listening");
  }

  shared_ptr<TSSLSocket> client = createSocket();

  // State shared with the completion handlers. It is owned through shared_ptr
  // because a handler can run after this function returns, e.g. the timer's
  // completion after deadline_.cancel().
  boost::shared_ptr<boost::mutex> mux(new boost::mutex);
  boost::shared_ptr<boost::condition_variable> cond(new boost::condition_variable);
  // would_block is the sentinel for "this operation has not completed yet".
  boost::shared_ptr<boost::system::error_code> acceptEc(
      new boost::system::error_code(boost::asio::error::would_block));
  boost::shared_ptr<boost::system::error_code> timerEc(
      new boost::system::error_code(boost::asio::error::would_block));

  acceptor_.async_accept(client->getSocket(), ssl_accept_handler(acceptEc, mux, cond));

  const bool useTimeout = accTimeout_ > 0;
  if (useTimeout) {
    deadline_.expires_from_now(boost::posix_time::milliseconds(accTimeout_));
    deadline_.async_wait(ssl_accept_handler(timerEc, mux, cond));
  }

  boost::unique_lock<boost::mutex> lock(*mux);
  while (*acceptEc == boost::asio::error::would_block
         && *timerEc == boost::asio::error::would_block) {
    cond->wait(lock);
  }

  bool timedOut = false;
  if (*acceptEc == boost::asio::error::would_block) {
    // The timer finished first; a success code means the deadline really
    // expired. Cancel the accept and wait for its handler, so the client
    // socket outlives the accept operation. If the accept raced the timer and
    // already succeeded, its handler reports success and the client is used.
    timedOut = !*timerEc;
    boost::system::error_code cancelEc;
    acceptor_.cancel(cancelEc);
    while (*acceptEc == boost::asio::error::would_block) {
      cond->wait(lock);
    }
  } else if (useTimeout) {
    // The accept finished first (successfully or not), so the deadline is no
    // longer needed.
    deadline_.cancel();
  }

  const boost::system::error_code acceptResult = *acceptEc;
  // Release the mutex before configuring and opening the client. A late timer
  // handler needs it, and holding it across client->open() would stall
  // whichever thread runs that handler.
  lock.unlock();

  if (acceptResult) {
    if (timedOut) {
      throw TTransportException(TTransportException::TIMED_OUT, "accept() timed out");
    }
    throw TTransportException(TTransportException::UNKNOWN, "accept()", acceptResult.value());
  }

  if (sendTimeout_ > 0) {
    client->setSendTimeout(sendTimeout_);
  }
  if (recvTimeout_ > 0) {
    client->setRecvTimeout(recvTimeout_);
  }
  if (keepAlive_) {
    client->setKeepAlive(keepAlive_);
  }

  client->open();

  return client;
}

/**
 * Cancels any pending asynchronous accept on the acceptor.
 *
 * Under Asio's cancel semantics, the accept handler is then invoked with
 * operation_aborted, which wakes a thread blocked in acceptImpl(). This does
 * nothing if the acceptor is not open. The throwing cancel() overload is used,
 * so a failure to cancel propagates as boost::system::system_error.
 */
void TSSLServerSocket::interrupt()
{
  if (!acceptor_.is_open()) {
    return;
  }
  acceptor_.cancel();
}

void TSSLServerSocket::close()
{
  if (acceptor_.is_open()) {
    boost::system::error_code ec;
    acceptor_.close(ec);
  }
}

}
}
}
